#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Simple script that prints generated Kleidi test cases for the torchao ops to stdout.
from string import Template


def add_test_string(kernel, m, n, k, g, has_bias, has_clamp):
    """Render the source text of one Kleidi test case.

    The test name is built from the shape (``m``, ``n``, ``k``) and the group
    size ``g``, followed by ``_bias`` and/or ``_clamp`` when the matching flag
    is truthy. The two flags are also written into the template arguments as
    the literals ``true`` / ``false``.

    Args:
        kernel: Kernel identifier, used both in the test name and as the
            template argument of ``get_ukernel_config_kleidi``.
        m, n, k: Problem dimensions passed to the generated test call.
        g: Group size passed to the generated test call.
        has_bias: Whether the generated test enables bias.
        has_clamp: Whether the generated test enables clamping.

    Returns:
        A one-element list holding the rendered test text, so callers can
        extend a running list with ``+=``.
    """
    # Flags only show up in the test name when they are enabled.
    bias_suffix = "_bias" if has_bias else ""
    clamp_suffix = "_clamp" if has_clamp else ""
    case_name = f"m{m}xn{n}xk{k}xg{g}" + bias_suffix + clamp_suffix

    # Index 0 -> "false", index 1 -> "true".
    bool_literals = ("false", "true")

    rendered = Template(
        """
TEST(test_linear_8bit_act_xbit_weight, Kleidi_${kernel}_${name}) {
  UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
  test_linear_8bit_act_xbit_weight<
      4 /*weight_nbit*/,
      false /*has_weight_zeros*/,
      $has_bias /*has_bias*/,
      $has_clamp /*has_clamp*/,
      true /*has_kleidi*/>(
      /*m=*/$m, /*n=*/$n, /*k=*/$k, /*group_size=*/$g, &ukernel_config);
}
"""
    ).safe_substitute(
        name=case_name,
        kernel=kernel,
        m=m,
        n=n,
        k=k,
        g=g,
        has_bias=bool_literals[bool(has_bias)],
        has_clamp=bool_literals[bool(has_clamp)],
    )

    return [rendered]


def get_test_block(kernel):
    """Return the concatenated test cases generated for one Kleidi kernel.

    Each row of the tables below is rendered through ``add_test_string`` and
    the rendered pieces are joined, in table order, into a single string.
    """
    # Assuming the given kleidi kernel can run with all these test cases.
    # Row layout: (m, n, k, g, has_bias, has_clamp).

    # GEMV, m == 1
    gemv_cases = [
        # subtile
        (1, 2 * 1, 32, 32, False, False),
        (1, 2 * 2, 32, 32, False, False),
        (1, 2 * 3, 32, 32, True, False),
        (1, 2 * 2, 32, 32, True, True),
        (1, 2 * 3, 32, 32, False, True),
        # larger n, even
        (1, 2 * 11, 32, 32, False, False),
        (1, 2 * 13, 32, 32, True, False),
        (1, 2 * 51, 32, 32, False, True),
        (1, 2 * 111, 32, 32, False, False),
        # larger n, odd
        (1, 11, 32, 32, False, False),
        (1, 13, 32, 32, True, False),
        (1, 51, 32, 32, False, True),
        (1, 111, 32, 32, False, False),
        # larger k and g, both multiples of 32
        (1, 2 * 7, 64, 32, False, False),
        (1, 2 * 11, 128, 32, True, False),
        (1, 2 * 13, 64, 64, False, True),
        (1, 2 * 17, 128, 64, False, False),
    ]

    # GEMM, m > 1
    gemm_cases = [
        # subtile
        (2, 2 * 1, 32, 32, False, False),
        (2, 2 * 2, 32, 32, False, False),
        (3, 2 * 3, 32, 32, True, False),
        (4, 2 * 4, 32, 32, True, True),
        (3, 2 * 3, 32, 32, False, True),
        # larger m
        (31, 2 * 1, 32, 32, False, False),
        (32, 2 * 2, 32, 32, False, False),
        (33, 2 * 3, 32, 32, True, False),
        (34, 2 * 4, 32, 32, True, True),
        (35, 2 * 3, 32, 32, False, True),
        # larger n, even
        (7, 2 * 11, 32, 32, False, False),
        (17, 2 * 13, 32, 32, True, False),
        (23, 2 * 51, 32, 32, False, True),
        (41, 2 * 111, 32, 32, False, False),
        # larger n, odd
        (7, 11, 32, 32, False, False),
        (17, 13, 32, 32, True, False),
        (23, 51, 32, 32, False, True),
        (41, 111, 32, 32, False, False),
        # larger k and g, both multiples of 32
        (19, 2 * 7, 64, 32, False, False),
        (23, 2 * 11, 128, 32, True, False),
        (29, 2 * 13, 64, 64, False, True),
        (101, 2 * 17, 128, 64, False, False),
    ]

    # add_test_string hands back a list, so flatten while joining.
    return "".join(
        text
        for shape in gemv_cases + gemm_cases
        for text in add_test_string(kernel, *shape)
    )


def main():
    """Print the generated Kleidi test source to stdout.

    Output starts with a "generated, do not modify" banner, followed by one
    section per kernel, all enclosed in a ``TORCHAO_ENABLE_KLEIDI``
    preprocessor guard. Sections for kernels whose name contains ``i8mm`` are
    additionally wrapped in a ``TORCHAO_ENABLE_ARM_I8MM`` guard.
    """
    # Spelled out with explicit "\n" escapes so the trailing space after
    # "tests" stays visible and is emitted exactly.
    section = Template(
        "\n"
        "/*****************/\n"
        "// ${kernel} tests \n"
        "/*****************/\n"
        "${prologue}\n"
        "${tests}\n"
        "${epilogue}\n"
    )
    i8mm_open = "#if defined(TORCHAO_ENABLE_ARM_I8MM)"
    i8mm_close = "#endif // TORCHAO_ENABLE_ARM_I8MM"

    print(
        "/* Generated by generate_tests.py */",
        "/* Do not modify */",
        "",
        "#if defined(TORCHAO_ENABLE_KLEIDI)",
        sep="\n",
    )

    for kernel in (
        "dotprod_1x4x32",
        "dotprod_1x8x32",
        "i8mm_4x8x32",
        "i8mm_8x4x32",
    ):
        # Only the i8mm kernels get the extra guard; others get empty lines.
        needs_i8mm = "i8mm" in kernel
        print(
            section.safe_substitute(
                prologue=i8mm_open if needs_i8mm else "",
                kernel=kernel,
                tests=get_test_block(kernel),
                epilogue=i8mm_close if needs_i8mm else "",
            )
        )

    print("#endif // TORCHAO_ENABLE_KLEIDI")


# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()
